/**
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.deephacks.westty.protobuf;

import com.google.common.base.Strings;
import com.google.protobuf.DescriptorProtos.FileDescriptorProto;
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.Message;
import org.deephacks.westty.protobuf.FailureMessages.Failure;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.enterprise.inject.Alternative;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.net.URL;
import java.nio.ByteBuffer;
import java.util.HashMap;

import static org.deephacks.westty.protobuf.FailureMessageException.FailureCode.BAD_REQUEST;

/**
 * Converts between protobuf messages and byte arrays that carry a leading
 * {@link Varint32} type number followed by the serialized message.
 *
 * <p>Message types become known to the serializer by registering compiled
 * descriptor sets ({@code FileDescriptorSet} files). Every message in a
 * registered descriptor must declare a field named {@code protoType}; the
 * <em>field number</em> of that field is used as the type number on the wire.
 * The descriptors {@code META-INF/failure.desc} and {@code META-INF/void.desc}
 * are registered by the constructor.
 *
 * <p>The registries are plain {@link HashMap}s and this class performs no
 * synchronization of its own.
 */
@Alternative
public class ProtobufSerializer {
    private static final Logger log = LoggerFactory.getLogger(ProtobufSerializer.class);
    private static final String UNRECOGNIZED_PROTOCOL_MSG = "Unrecognized protocol.";

    /** Type number -> static {@code parseFrom(byte[])} method of the generated message class. */
    private final HashMap<Integer, Method> numToMethod = new HashMap<>();
    /** Full protobuf message name -> type number. */
    private final HashMap<String, Integer> protoToNum = new HashMap<>();

    /**
     * Creates a serializer with the built-in failure and void descriptors registered.
     */
    public ProtobufSerializer() {
        registerResource("META-INF/failure.desc");
        registerResource("META-INF/void.desc");
    }

    /**
     * Registers every message type found in the descriptor set at the given URL.
     *
     * @param protodesc location of a serialized {@code FileDescriptorSet}
     * @throws RuntimeException if the URL cannot be opened or read, or if the
     *         descriptor cannot be registered
     */
    public void register(URL protodesc) {
        // try-with-resources: the stream must be closed whether or not parsing succeeds.
        try (InputStream in = protodesc.openStream()) {
            registerDesc(protodesc.getFile(), in);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Registers every message type found in the given descriptor set file.
     *
     * @param protodesc file containing a serialized {@code FileDescriptorSet}
     * @throws RuntimeException if the file cannot be opened or closed, or if the
     *         descriptor cannot be registered
     */
    public void register(File protodesc) {
        // IOException (not just FileNotFoundException) because close() may fail too.
        try (InputStream in = new FileInputStream(protodesc)) {
            registerDesc(protodesc.getName(), in);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Registers a descriptor set that is looked up through the current thread's
     * context class loader.
     *
     * @param protodesc resource name of a serialized {@code FileDescriptorSet}
     * @throws IllegalArgumentException if no such resource can be found
     * @throws RuntimeException if the resource cannot be read or registered
     */
    public void registerResource(String protodesc) {
        URL url = Thread.currentThread().getContextClassLoader().getResource(protodesc);
        if (url == null) {
            // getResource signals "not found" with null; fail with the name instead of an NPE.
            throw new IllegalArgumentException(
                    "Could not find protobuf descriptor resource " + protodesc);
        }
        register(url);
    }

    /**
     * Parses a descriptor set from the stream and records, for each message type,
     * its type number and the {@code parseFrom(byte[])} method of its generated class.
     *
     * <p>The generated class name is derived as
     * {@code java_package + "." + java_outer_classname + "$" + messageName} and is
     * loaded through the current thread's context class loader. The stream is not
     * closed here; that is the caller's responsibility.
     *
     * @param name label for the descriptor, used in error and log messages
     * @param in stream positioned at a serialized {@code FileDescriptorSet}
     * @throws RuntimeException wrapping any failure, including a missing
     *         {@code protoType} field, {@code java_package} or {@code java_outer_classname}
     */
    private void registerDesc(String name, InputStream in) {
        try {
            FileDescriptorSet descriptorSet = FileDescriptorSet.parseFrom(in);
            for (FileDescriptorProto fdp : descriptorSet.getFileList()) {
                // Built with no dependencies: each file must be self-contained.
                FileDescriptor fd = FileDescriptor.buildFrom(fdp, new FileDescriptor[] {});
                for (Descriptor desc : fd.getMessageTypes()) {
                    FieldDescriptor fdesc = desc.findFieldByName("protoType");
                    if (fdesc == null) {
                        throw new IllegalArgumentException(name
                                + ".proto file must define protoType field "
                                + "with unique number that identify proto type");
                    }
                    String packageName = fdp.getOptions().getJavaPackage();
                    if (Strings.isNullOrEmpty(packageName)) {
                        throw new IllegalArgumentException(name
                                + ".proto file must define java_package");
                    }
                    String simpleClassName = fdp.getOptions().getJavaOuterClassname();
                    if (Strings.isNullOrEmpty(simpleClassName)) {
                        throw new IllegalArgumentException(name
                                + ".proto file must define java_outer_classname");
                    }
                    String className = packageName + "." + simpleClassName + "$" + desc.getName();
                    Class<?> cls = Thread.currentThread().getContextClassLoader()
                            .loadClass(className);
                    // The field number of "protoType" doubles as the wire type number.
                    protoToNum.put(desc.getFullName(), fdesc.getNumber());
                    numToMethod.put(fdesc.getNumber(), cls.getMethod("parseFrom", byte[].class));
                }
            }
            log.debug("Registered protobuf resource {}.", name);
        } catch (Exception e) {
            throw new RuntimeException("Failed to register protobuf descriptor " + name, e);
        }
    }

    /**
     * Decodes a byte array produced by {@link #write(Object)}.
     *
     * <p>The leading {@link Varint32} is read as the type number and the remaining
     * bytes are handed to the {@code parseFrom(byte[])} method registered for it.
     * Decoding problems are never thrown: they are reported as a returned
     * {@link Failure} with code {@code BAD_REQUEST}.
     *
     * @param bytes type number followed by the serialized message
     * @return the parsed message; a {@link Failure} whose message is
     *         {@code "proto_type=<n>"} when the type number is not registered; or a
     *         {@link Failure} with a generic message when decoding fails
     * @throws Exception declared for interface compatibility
     */
    public Object read(byte[] bytes) throws Exception {
        try {
            ByteBuffer buf = ByteBuffer.wrap(bytes);
            Varint32 vint = new Varint32(buf);
            int protoTypeNum = vint.read();
            buf = vint.getByteBuffer();
            byte[] message = new byte[buf.remaining()];
            buf.get(message);
            Method m = numToMethod.get(protoTypeNum);
            if (m == null) {
                return Failure.newBuilder().setCode(BAD_REQUEST.getCode())
                        .setMsg("proto_type=" + protoTypeNum).build();
            }
            // parseFrom is static, hence the null receiver.
            return m.invoke(null, message);
        } catch (Exception e) {
            // The caller only gets a generic failure; keep the cause discoverable in the log.
            log.debug("Could not decode protobuf message.", e);
            return Failure.newBuilder().setCode(BAD_REQUEST.getCode())
                    .setMsg(UNRECOGNIZED_PROTOCOL_MSG).build();
        }
    }

    /**
     * Encodes a registered protobuf message as its {@link Varint32} type number
     * followed by the serialized message bytes.
     *
     * @param proto the message to encode; must be a {@link Message}
     * @return the type number bytes followed by the message bytes
     * @throws IllegalArgumentException if the message type has not been registered
     * @throws ClassCastException if {@code proto} is not a {@link Message}
     * @throws IOException declared for interface compatibility
     */
    public byte[] write(Object proto) throws IOException {
        Message msg = (Message) proto;
        String protoName = msg.getDescriptorForType().getFullName();
        Integer num = protoToNum.get(protoName);
        if (num == null) {
            throw new IllegalArgumentException("Could not find protoType mapping for "
                    + protoName);
        }
        byte[] msgBytes = msg.toByteArray();
        Varint32 vint = new Varint32(num);
        int vsize = vint.getSize();
        byte[] bytes = new byte[vsize + msgBytes.length];
        System.arraycopy(vint.write(), 0, bytes, 0, vsize);
        System.arraycopy(msgBytes, 0, bytes, vsize, msgBytes.length);
        return bytes;
    }
}